Qual é a maneira mais fácil de transformar o tensor de forma (batch_size, height, width) preenchido com valores n em tensor de forma (batch_size, n, height, width)? Eu criei a solução abaixo, mas parece que há uma maneira mais fácil e rápida de fazer isso def batch_tensor_to_onehot (tnsr, classes): tnsr = tnsr.unsqueeze (1) res = [] para cls no intervalo (classes): res.append ((tnsr == cls) .long ()) return torch.cat (res, dim = 1)
2021-02-20 08:19:40
Você pode usar torch.nn.functional.one_hot. Para o seu caso: a = torch.nn.functional.one_hot (tnsr, num_classes = classes) out = a.permute (0, 3, 1, 2) | Você também pode usar Tensor.scatter_ que evita .permute, mas é indiscutivelmente mais difícil de entender do que o método direto proposto por @Alpha. def batch_tensor_to_onehot (tnsr, classes): result = torch.zeros (tnsr.shape [0], classes, * tnsr.shape [1:], dtype = torch.long, device = tnsr.device) result.scatter_ (1, tnsr.unsqueeze (1), 1) resultado de retorno Resultados de benchmarking Fiquei curioso e decidi comparar as três abordagens. Descobri que não parece haver uma diferença relativa significativa entre os métodos propostos em relação ao tamanho, largura ou altura do lote. Principalmente o número de classes era o fator de distinção. Claro, como em qualquer referência, a milhagem pode variar. Os benchmarks foram coletados usando índices aleatórios e usando tamanho de lote, altura, largura = 100. Cada experimento foi repetido 20 vezes com a média sendo relatada. A experiência num_classes = 100 é executada uma vez antes da criação de perfil para aquecimento. Os resultados da CPU mostram que o método original provavelmente era melhor para num_classes menores que cerca de 30, enquanto para GPU a abordagem scatter_ parece ser mais rápida. Testes realizados no Ubuntu 18.04, NVIDIA 2060 Super, i7-9700K O código usado para benchmarking é fornecido abaixo: importar tocha from tqdm import tqdm tempo de importação import matplotlib.pyplot as plt def batch_tensor_to_onehot_slavka (tnsr, classes): tnsr = tnsr.unsqueeze (1) res = [] para cls no intervalo (classes): res.append ((tnsr == cls) .long ()) return torch.cat (res, dim = 1) def batch_tensor_to_onehot_alpha (tnsr, classes): result = torch.nn.functional.one_hot (tnsr, num_classes = classes) retornar result.permute (0, 3, 1, 2) def batch_tensor_to_onehot_jodag (tnsr, classes): result = torch.zeros (tnsr.shape [0], classes, * tnsr.shape [1:], dtype = torch.long, device = tnsr.device) result.scatter_ (1, tnsr.unsqueeze (1), 1) resultado de retorno def main (): num_classes = [2, 10, 25, 50, 100] altura = 100 largura = 100 bs = [100] * 20 para d em ['cpu', 'cuda']: times_slavka = [] times_alpha = [] times_jodag = [] aquecimento = verdadeiro para c em tqdm ([num_classes [-1]] + num_classes, ncols = 0): tslavka = 0 talpha = 0 tjodag = 0 para b em bs: tnsr = torch.randint (c, (b, altura, largura)). to (dispositivo = d) t0 = time.time () y = batch_tensor_to_onehot_slavka (tnsr, c) torch.cuda.synchronize () tslavka + = time.time () - t0 se não aquecimento: times_slavka.append (tslavka / len (bs)) para b em bs: tnsr = torch.randint (c, (b, altura, largura)). to (dispositivo = d) t0 = time.time () y = batch_tensor_to_onehot_alpha (tnsr, c) torch.cuda.synchronize () talpha + = time.time () - t0 se não aquecimento: times_alpha.append (talpha / len (bs)) para b em bs: tnsr = torch.randint (c, (b, altura, largura)). to (dispositivo = d) t0 = time.time () y = batch_tensor_to_onehot_jodag (tnsr, c) torch.cuda.synchronize () tjodag + = time.time () - t0 se não aquecimento: times_jodag.append (tjodag / len (bs)) aquecimento = falso fig = plt.figure () ax = fig.subplots () ax.plot (num_classes, times_slavka, label = 'Slavka-cat') ax.plot (num_classes, times_alpha, label = 'Alpha-one_hot') ax.plot (num_classes, times_jodag, label = 'jodag-scatter_') ax.set_xlabel ('num_classes') ax.set_ylabel ('hora (s)') ax.set_title (f '{d} benchmark') ax.legend () plt.savefig (f '{d} .png') plt.show () if __name__ == "__main__": a Principal() | sua resposta StackExchange.ifUsing ("editor", function () { StackExchange.using ("externalEditor", function () { StackExchange.using ("snippets", function () { StackExchange.snippets.init (); }); }); }, "partes de codigo"); StackExchange.ready (function () { var channelOptions = { tags: "" .split (""), id: "1" }; initTagRenderer ("". split (""), "" .split (""), channelOptions); StackExchange.using ("externalEditor", function () { // Tem que disparar o editor após os snippets, se os snippets estiverem habilitados if (StackExchange.settings.snippets.snippetsEnabled) { StackExchange.using ("snippets", function () { createEditor (); }); } outro { createEditor (); } }); function createEditor () { StackExchange.prepareEditor ({ useStacksEditor: false, heartbeatType: 'answer', autoActivateHeartbeat: false, convertImagesToLinks: true, noModals: true, showLowRepImageUploadWarning: true, reputaçãoToPostImages: 10, bindNavPrevention: true, postfix: "", imageUploader: { brandingHtml: "Powered by \ u003ca href = \" https: //imgur.com/ \ "\ u003e \ u003csvg class = \" svg-icon \ "width = \" 50 \ "height = \" 18 \ "viewBox = \ "0 0 50 18 \" fill = \ "none \" xmlns = \ "http: //www.w3.org/2000/svg \" \ u003e \ u003cpath d = \ "M46.1709 9.17788C46.1709 8.26454 46,2665 7,94324 47,1084 7,58816C47.4091 7,46349 47,7169 7,36433 48,0099 7,26993C48.9099 6,97997 49,672 6,73443 49,672 5,93063C49,672 5,22043 48,9832 4,61182 48,1414 4,61182C47,4335 4,26993C48.9099 6,97997 49,672 6,73443 49,672 5,93063C49,672 5,22043 48,9832 4,61182 48,1414 4,61182C47,4335 4,6431182 4,25.654,23,7623,74,6281 4,9281 467823,621 4,96281 4,9823,623,623,461,45,023,74,6281 4,9823,1623,623,14,6281 4,9282 C 4,9281 4,923,623,623,76,2 53,76,23,76,823,14,65,0281 4,928,023,623,14,65,0281 4,928,023,623. 43.1481 6.59048V11.9512C43.1481 13.2535 43.6264 13.8962 44.6595 13.8962C45.6924 13.8962 46.1709 13.253546.1709 11.9512V9.17788Z \ "/ \ u003e \ u003cpath d = \" M32.492 10.1419C32.492 12.6954 34.1182 14.0484 37.0451 14.0484C39.9723 14.0484 41.5985 12.6954 41.532327 4.669.4394398.439.532.532.54.3394.339.049.0485.0451 14.0484C39.9723 14.0484 41.5985 12.6954 41.532.5327 4.669.04398.439.04928. 38.5948 5.28821 38.5948 6.59049V9.60062C38.5948 10.8521 38.2696 11.5455 37.0451 11.5455C35.8209 11.5455 35.4954 10.8521 35.4954 9.60062V6.59049C35.4954 5.28821 35.0173 4.66232 32.0451 11.5455C35.8209 11.5455 35.4954 10.8521 35.4954 9.60062V6.59049C35.4954 5.28821 35.0173 4.662324.532.34.532.932c 32.669.46.532e 32.662.334. fill-rule = \ "evenodd \" clip-rule = \ "evenodd \" d = \ "M25.6622 17.6335C27.8049 17.6335 29.3739 16,9402 30,2537 15.6379C30.8468 14.7755 30,9615 13.5579 30,9615 11,9512V6.590494.4335 29.3739 16,9402 30.2537 15.6379C30.8468 14.7755 30.9615 13.5579 30.9615 11,9512V6.59049C30.9615 5,28821 29.4502 4.66231C28.9913 4.66231 28.4555 4.94978 28.1109 5.50789C27.499 4.86533 26.7335 4.56087 25.7005 4.56087C23.1369 4.56087 21.0134 6.57349 21.0134 9.27932C21.0134 11.9852.423.003 13.913 25.3754.13.136. C28. 1256 12.8854 28.1301 12.9342 28.1301 12.983C28.1301 14.4373 27.2502 15.2321 25,777 15.2321C24.8349 15.2321 24.1352 14.9821 23.5661 14.7787C23.176 14.6393 22,8472 14.5218 22.5437 14.52.125.527.217.2321C24.8349 15.2321 24.1352 14.9821 23.5661 14.7787C23.176 14.6393 22.8472 14.5218 22.5437 14.52.18.02.17.217.217.217.217.217.217.217.217.217,217,217,217,217,217,217,217,217,217,217,217,217,217,217,217,217,217,217,217,217,217,217,217,217,217.217,217,217,217,217,217,217,217,217,217,217,217,217,217,217,217,217,217,217,217,217,217,217,2 133,2172 c. C24.1317 7,94324 24,9928 7,09766 26,1024 7,09766C27,2119 7,09766 28,0918 7,94324 28,0918 9,27932C28,0918 10,6321 27,2311 11,5116 26,1024 11,5116C24,9737 11,5116 24,1317 10,6491 24,1317 uc. 8045 13,2535 17,2637 13,8962 18,2965 13,8962C19,3298 13,8962 19,8079 13,2535 19,8079 11,9512V8.12928C19,8079 5,82936 18,4879 4,62866 16,7027 4,62866C15,1594 4,62866 14,279 4,98375 13,36099 19,8079 11,9512V8.12928C19,8079 5,82936 18,4879 4,62866 16,4027 4,62866C15,1594 4,62866 14,279 4,98375 13,36099 19,8079 11,9512V8.12928C19,8079 5,82936 18,4879 4,62866 16,4027 4,62866C15,1594 4,62866 14,279 4,98375 13,36099 9,88013C12.656 C 5,6286 4,6286,36,35,35,6286 4,6286C 58314 4,9328 7.10506 4.66232 6.51203 4.66232C5.47873 4.66232 5.00066 5.28821 5.00066 6.59049V11.9512C5.00066 13.2535 5.47873 13.8962 6.51203 13.8962C7.54479 13.8962 8.0232 13 .2535 8.0232 11.9512V8.90741C8.0232 7.58817 8.44431 6.91179 9.53458 6.91179C10.5104 6.91179 10.893 7.58817 10.893 8.94108V11.9512C10.893 13.2535 11.3711 13.8962 12.4044 13.8915.913.43.915.915.915.913.915.915.915.915.913.915.913.915.913.512.513.915.913.915.913.512.513.915.913.512. C16.4027 6.91179 16.8045 7.58817 16.8045 8.94108V11.9512Z \ "/ \ u003e \ u003cpath d = \" M3.31675 6.59049C3.31675 5.28821 2.83866 4.66232 1.82471 4.66232C0.791758 4.66232 0,39013354 5.93512 0,3513.513.513.513.512. 1.82471 13.8962C2.85798 13.8962 3.31675 13.2535 3.31675 11.9512V6.59049Z \ "/ \ u003e \ u003cpath d = \" M1.87209 0.400291C0.843612 0.400291 0 1.1159 0 1.98861C0 2.87869 0,8228.923 3.576 3.561 C0 2.87869 0.8228.923 3.576 3.576 1.87209 3.776 3.8228.946 3.576 3.561 C7209 3.776 3.17823 1.978.923 3.576 1.87209 0,400291C0.843612 0.400291 0 1.1159 0 1.98861C0 2.87869 0.8228.923 3.576 3.561 c. C3.7234 1,1159 2,90056 0,400291 1,87209 0,400291Z \ "fill = \" # 1BB76E \ "/ \ u003e \ u003c / svg \ u003e \ u003c / a \ u003e", contentPolicyHtml: "Contribuições do usuário licenciadas sob \ u003ca href = \" https: //stackoverflow.com/help/licensing \ "\ u003ecc by-sa \ u003c / a \ u003e \ u003ca href = \" https://stackoverflow.com / legal / content-policy \ "\ u003e (política de conteúdo) \ u003c / a \ u003e", allowUrls: true }, onDemand: true, discardSelector: ".discard-answer" , imediatamenteShowMarkdownHelp: true, enableTables: true, enableSnippets: true }); } }); Obrigado por contribuir com uma resposta para Stack Overflow! Certifique-se de responder à pergunta. Forneça detalhes e compartilhe sua pesquisa! Mas evite ... Pedir ajuda, esclarecimento ou responder a outras respostas. Fazer declarações com base em opinião; apoie-os com referências ou experiência pessoal. Para saber mais, veja nossas dicas sobre como escrever boas respostas. Rascunho salvo Rascunho descartado Cadastre-se ou faça o login StackExchange.ready (function () { StackExchange.helpers.onClickDraftSave ('# login-link'); }); Inscreva-se usando o Google Cadastre-se usando o Facebook Inscreva-se usando e-mail e senha Enviar Postar como convidado Nome O email Obrigatório, mas nunca mostrado StackExchange.ready ( function () { StackExchange.openid.initPostLogin ('. New-post-login', 'https% 3a% 2f% 2fstackoverflow.com% 2fquestions% 2f62245173% 2fpytorch-transform-tensor-to-one-hot% 23new-answer', 'question_page' ); } ); Postar como convidado Nome O email Obrigatório, mas nunca mostrado Publique a sua resposta Descartar Ao clicar em “Publique sua resposta”, você concorda com nossos termos de serviço, política de privacidade e política de cookies Não é a resposta que você está procurando? Navegue por outras questões com a tag python pytorch tensor one-hot-encoding ou faça sua própria pergunta.